!! Copyright (C) 2009,2010,2011,2012  Marco Restelli
!!
!! This file is part of:
!!   FEMilaro -- Finite Element Method toolkit
!!
!! FEMilaro is free software; you can redistribute it and/or modify it
!! under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 3 of the License, or
!! (at your option) any later version.
!!
!! FEMilaro is distributed in the hope that it will be useful, but
!! WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
!! General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with FEMilaro; If not, see <http://www.gnu.org/licenses/>.
!!
!! author: Marco Restelli                   <marco.restelli@gmail.com>

!> \file
!! Simple program to test the linear solvers.
!!
!! \n
!!
!! A small system, defined in sparse format, is solved with various
!! linear solvers.
!!
!! The matrix is stored as \f$M^T\f$: since sparse matrices are stored
!! by columns, storing the transpose gives row-wise access to \f$M\f$,
!! which better suits the iterative solvers.
!!
!! To use this program, create a file named
!! <tt>linsolver_test_data.data</tt> containing two variables: \c mmm
!! (the sparse system matrix) and \c rhs (the right-hand side). The
!! resulting system is used to test the linear solvers.
!!
!! To create the file, use for instance the following octave
!! instructions:
!! \code{.m}
!! mmm = sparse([
!!      1 2 3 4;
!!      0 2 0 0;
!!     -1 0 1 0;
!!     -3 0 0 0
!!   ]);
!! rhs = [1,2,3,4]';
!! save linsolver_test_data.data mmm rhs
!! \endcode
module linsolver_test_problem

 use, intrinsic :: iso_fortran_env, only: &
   error_unit

 use mod_kinds, only: &
   wp

 use mod_state_vars, only: &
   c_stv

 use mod_linsolver, only: &
   c_linpb, c_itpb, c_mumpspb, c_umfpackpb, c_pastixpb

 use mod_sparse, only: &
   t_col, transpose, matmul, clear

 implicit none

 private
 
 public :: &
   mmmt, rhs, t_x, t_linpb_mumps, t_linpb_it, t_linpb_umfpack, &
   t_linpb_pastix

 !> State variable of the test problem: a plain array of unknowns.
 type, extends(c_stv) :: t_x
  real(wp), allocatable :: x(:)
 contains
  procedure, pass(x) :: incr => x_incr
  procedure, pass(x) :: tims => x_tims
  procedure, pass(z) :: copy => x_copy
  procedure, pass(x) :: scal => x_scal
  !> Compiler bug (see comments in \c mod_state_vars)
  procedure, pass(x) :: source => x_source
  procedure, pass(x) :: source_vect => x_source_vect
 end type t_x

 !> See \c mod_mumpsintf for details.
 type, extends(c_mumpspb) :: t_linpb_mumps
 contains
  procedure, pass(s) :: xassign => mumps_xassign
 end type t_linpb_mumps

 !> See \c mod_pastixintf for details.
 type, extends(c_pastixpb) :: t_linpb_pastix
 contains
  procedure, pass(s) :: xassign => pastix_xassign
 end type t_linpb_pastix

 !> See \c mod_umfintf for details.
 type, extends(c_umfpackpb) :: t_linpb_umfpack
 contains
  procedure, pass(s) :: xassign => umf_xassign
 end type t_linpb_umfpack

 !> See \c mod_iterativesolvers_base for details.
 type, extends(c_itpb) :: t_linpb_it
 contains
  procedure, nopass :: pres => nop_res
  procedure, nopass :: pkry => a_times_x
 end type t_linpb_it

 !> System matrix, stored transposed (\f$M^T\f$).
 type(t_col), save, target :: mmmt
 !> Right-hand side of the linear system.
 real(wp), allocatable, target :: rhs(:)

contains

 !> Abort with a diagnostic when a procedure of this module receives
 !! a \c c_stv argument whose dynamic type is not \c t_x. Silently
 !! ignoring such arguments would make a solver return wrong results
 !! without any indication of the problem.
 subroutine type_mismatch(where)
  character(len=*), intent(in) :: where

   write(error_unit,'(a)') &
     'linsolver_test_problem: unexpected dynamic type in '//where
   error stop
 end subroutine type_mismatch

 !> \f$x \leftarrow x + y\f$
 subroutine x_incr(x,y)
  class(c_stv), intent(in)    :: y
  class(t_x),   intent(inout) :: x

   select type(y); type is(t_x)
    x%x = x%x + y%x
   class default
    call type_mismatch('x_incr')
   end select
 end subroutine x_incr

 !> \f$x \leftarrow r\,x\f$
 subroutine x_tims(x,r)
  real(wp),   intent(in)    :: r
  class(t_x), intent(inout) :: x

   x%x = r*x%x
 end subroutine x_tims

 !> \f$z \leftarrow x\f$
 subroutine x_copy(z,x)
  class(c_stv), intent(in)    :: x
  class(t_x),   intent(inout) :: z

   select type(x); type is(t_x)
    z%x = x%x
   class default
    call type_mismatch('x_copy')
   end select
 end subroutine x_copy

 !> Allocate \c y as a \c t_x holding a copy of \c x.
 subroutine x_source(y,x)
  class(t_x),                intent(in)  :: x
  class(c_stv), allocatable, intent(out) :: y

   allocate(t_x::y)
   select type(y); type is(t_x)
    allocate(y%x(size(x%x)))
    y%x = x%x
   end select
 end subroutine x_source

 !> Allocate \c y as an array of \c m objects of type \c t_x, each
 !! holding a copy of \c x.
 subroutine x_source_vect(y,x,m)
  integer, intent(in) :: m
  class(t_x),                intent(in)  :: x
  class(c_stv), allocatable, intent(out) :: y(:)

  integer :: i

   allocate(t_x::y(m))
   select type(y); type is(t_x)
    do i=1,m
      allocate(y(i)%x(size(x%x)))
      y(i)%x = x%x
    enddo
   end select
 end subroutine x_source_vect

 !> Store the solution computed by MUMPS into \c x.
 !! \c s is required by the binding interface but not used.
 subroutine mumps_xassign(x,s,mumps_x)
  real(wp),             intent(in)    :: mumps_x(:)
  class(t_linpb_mumps), intent(inout) :: s
  class(c_stv),         intent(inout) :: x

   select type(x); type is(t_x)
    x%x = mumps_x
   class default
    call type_mismatch('mumps_xassign')
   end select
 end subroutine mumps_xassign

 !> Store the solution computed by PaStiX into \c x.
 !! \c s is required by the binding interface but not used.
 subroutine pastix_xassign(x,s,pastix_x)
  real(wp),             intent(in)    :: pastix_x(:)
  class(t_linpb_pastix), intent(inout) :: s
  class(c_stv),         intent(inout) :: x

   select type(x); type is(t_x)
    x%x = pastix_x
   class default
    call type_mismatch('pastix_xassign')
   end select
 end subroutine pastix_xassign

 !> Store the solution computed by UMFPack into \c x.
 !! \c s is required by the binding interface but not used.
 subroutine umf_xassign(x,s,umfpack_x)
  real(wp),               intent(in)    :: umfpack_x(:)
  class(t_linpb_umfpack), intent(inout) :: s
  class(c_stv),           intent(inout) :: x

   select type(x); type is(t_x)
    ! umfpack_x is already real(wp): no kind conversion is needed
    x%x = umfpack_x
   class default
    call type_mismatch('umf_xassign')
   end select
 end subroutine umf_xassign

 !> Residual without preconditioning: \f$r = b - Mx\f$. Since \c mmmt
 !! holds \f$M^T\f$, \f$Mx\f$ is computed as \f$x^T M^T\f$.
 subroutine nop_res(r,x)
  class(c_stv), intent(in) :: x
  class(c_stv), intent(inout) :: r

   select type(x); type is(t_x)
    select type(r); type is(t_x)
     r%x = rhs - matmul(x%x,mmmt)
    class default
     call type_mismatch('nop_res')
    end select
   class default
    call type_mismatch('nop_res')
   end select
 end subroutine nop_res

 !> Krylov vector without preconditioning: \f$r = Mx\f$, computed as
 !! \f$x^T M^T\f$ since \c mmmt holds \f$M^T\f$.
 subroutine a_times_x(r,x)
  class(c_stv), intent(in) :: x
  class(c_stv), intent(inout) :: r

   select type(x); type is(t_x)
    select type(r); type is(t_x)
     r%x = matmul(x%x,mmmt)
    class default
     call type_mismatch('a_times_x')
    end select
   class default
    call type_mismatch('a_times_x')
   end select
 end subroutine a_times_x

 !> Euclidean scalar product \f$x\cdot y\f$.
 function x_scal(x,y) result(s)
  class(t_x),   intent(in) :: x
  class(c_stv), intent(in) :: y
  real(wp) :: s

   select type(y); type is(t_x)
    s = dot_product(x%x,y%x)
   class default
    ! define the result anyway; type_mismatch does not return
    s = 0.0_wp
    call type_mismatch('x_scal')
   end select
 end function x_scal
end module linsolver_test_problem


program linsolver_test

 use mod_messages, only: &
   mod_messages_constructor, &
   mod_messages_destructor,  &
   error,   &
   warning, &
   info

 use mod_kinds, only: &
   mod_kinds_constructor, &
   mod_kinds_destructor,  &
   wp

 use mod_output_control, only: &
   mod_output_control_constructor, &
   mod_output_control_destructor,  &
   elapsed_format, base_name

 use mod_mpi_utils, only: &
   mod_mpi_utils_constructor, &
   mod_mpi_utils_destructor,  &
   mpi_init, mpi_thread_single, mpi_thread_multiple, &
   mpi_finalize, &
   mpi_comm_world, mpi_comm_rank

 use mod_sparse, only: &
   mod_sparse_constructor, &
   mod_sparse_destructor,  &
   t_col, transpose, matmul, clear

 use mod_octave_io, only: &
   mod_octave_io_constructor, &
   mod_octave_io_destructor,  &
   write_octave, read_octave_al

 use mod_octave_io_sparse, only: &
   mod_octave_io_sparse_constructor, &
   mod_octave_io_sparse_destructor,  &
   write_octave, read_octave

 use mod_state_vars, only: &
   mod_state_vars_constructor, &
   mod_state_vars_destructor,  &
   c_stv

 use mod_linsolver, only: &
   mod_linsolver_constructor, &
   mod_linsolver_destructor,  &
   c_linpb, c_itpb, c_mumpspb, &
   gmres

 use linsolver_test_problem, only: &
   mmmt, rhs, t_x, t_linpb_mumps, t_linpb_it, t_linpb_umfpack, &
   t_linpb_pastix

 implicit none

 logical, parameter :: & ! select here the solvers to test
   test_mumps   = .false., &
   test_gmres   = .true.,  &
   test_umfpack = .true.,  &
   test_pastix  = .true.

 !> I/O unit used for the input data file and for the result files
 integer, parameter :: io_unit = 25

 integer :: i, ierr, provided
 character(len=1000) :: message

 ! linear system
 type(t_x) :: x
 integer, allocatable, target :: gij(:)
 class(c_linpb), allocatable :: linpb

  ! Note: PaStiX uses scotch, which in turn *might* require that MPI
  ! is initialized with thread support
  !call mpi_init(ierr)
  call mpi_init_thread(mpi_thread_multiple,provided,ierr)

  call mod_messages_constructor()
  call mod_kinds_constructor()
  call mod_output_control_constructor('linsolver_test_output')
  call mod_mpi_utils_constructor()
  call mod_sparse_constructor()
  call mod_octave_io_constructor()
  call mod_octave_io_sparse_constructor()
  call mod_state_vars_constructor()
  call mod_linsolver_constructor()

  ! The MPI standard guarantees that the thread support levels are
  ! monotonically ordered, so a smaller value means less support.
  if(provided.lt.mpi_thread_multiple) then
    write(message,'(a,i0,a,i0,a)') &
      'MPI provides thread support level ',provided, &
      ' (requested ',mpi_thread_multiple,'); PaStiX might fail.'
    call warning('linsolver_test','',message)
  endif

  ! Define the linear system
  open(io_unit,file='linsolver_test_data.data',     &
    status='old',action='read',form='formatted',iostat=ierr)
  if(ierr.ne.0) then
    call error('linsolver_test','', &
      'Unable to open the input file "linsolver_test_data.data".')
    error stop
  endif
  call read_octave(mmmt,'mmm',io_unit)
  mmmt = transpose(mmmt)
  call read_octave_al(rhs,'rhs',io_unit)
  close(io_unit,iostat=ierr)
  if(size(rhs).ne.mmmt%m) then
    write(message,'(a,i0,a,i0,a)') &
      'Inconsistent input: matrix dimension ',mmmt%m, &
      ' but rhs has ',size(rhs),' entries.'
    call error('linsolver_test','',message)
    error stop
  endif
  ! zero-based global indexes of the unknowns
  allocate(gij(0:mmmt%m-1)); gij = (/ (i, i=0,mmmt%m-1) /)
  allocate(x%x(mmmt%m))

  !---------------------------------------------------------------------

  ! MUMPS solver
  if(test_mumps) then
    allocate(t_linpb_mumps::linpb)
    if(linpb%working_implementation()) then
      x%x = 0.0_wp ! initial guess (not used by MUMPS)
      select type(linpb); type is(t_linpb_mumps)
        linpb%distributed    = .true.
        linpb%poo            = 3 ! SCOTCH
        linpb%transposed_mat = .true.
        linpb%gn             = mmmt%m
        linpb%m              => mmmt
        linpb%rhs            => rhs
        linpb%gij            => gij
        linpb%mpi_comm       = mpi_comm_world
      end select
      call solve_and_report('MUMPS')
    endif
    deallocate(linpb)
  endif

  !---------------------------------------------------------------------

  ! GMRES solver
  if(test_gmres) then
    allocate(t_linpb_it::linpb)
    if(linpb%working_implementation()) then
      x%x = 0.0_wp ! initial guess
      select type(linpb); type is(t_linpb_it)
        linpb%abstol    = .true.
        linpb%tolerance = 1.0e-8_wp
        linpb%nmax      = 5  ! size of the Krylov space
        linpb%rmax      = 10 ! restarts
        linpb%solver    => gmres
        linpb%mpi_comm  = mpi_comm_world
        call mpi_comm_rank(linpb%mpi_comm,linpb%mpi_id,ierr)
      end select
      call solve_and_report('GMRES')
    endif
    deallocate(linpb)
  endif

  !---------------------------------------------------------------------

  ! UMFPack solver
  if(test_umfpack) then
    allocate(t_linpb_umfpack::linpb)
    if(linpb%working_implementation()) then
      x%x = 0.0_wp ! initial guess (not used)
      select type(linpb); type is(t_linpb_umfpack)
        linpb%print_level    = 5
        linpb%transposed_mat = .true.
        linpb%m              => mmmt
        linpb%rhs            => rhs
      end select
      call solve_and_report('UMFPack')
    endif
    deallocate(linpb)
  endif

  !---------------------------------------------------------------------

  ! PaStiX solver
  if(test_pastix) then
    allocate(t_linpb_pastix::linpb)
    if(linpb%working_implementation()) then
      x%x = 0.0_wp ! initial guess (not used)
      select type(linpb); type is(t_linpb_pastix)
        linpb%transposed_mat = .true.
        linpb%gn             = mmmt%m
        linpb%m              => mmmt
        linpb%rhs            => rhs
        linpb%gij            => gij
        linpb%mpi_comm       = mpi_comm_world
      end select
      call solve_and_report('PaStiX')
    endif
    deallocate(linpb)
  endif

  !---------------------------------------------------------------------

  call clear(mmmt)
  deallocate(rhs,gij,x%x)

  call mod_linsolver_destructor()
  call mod_state_vars_destructor()
  call mod_octave_io_sparse_destructor()
  call mod_octave_io_destructor()
  call mod_sparse_destructor()
  call mod_mpi_utils_destructor()
  call mod_output_control_destructor()
  call mod_kinds_destructor()
  call mod_messages_destructor()
  call mpi_finalize(ierr)

contains

 !> Run the configured solver \c linpb on the test system and report
 !! the result.
 !!
 !! The system is analysed, factorized and solved into \c x, then the
 !! solver is cleaned. The residual norm is printed, and the matrix,
 !! right-hand side and solution are written to
 !! <tt>linsolver_test_data_<solver_name>.results</tt>.
 subroutine solve_and_report(solver_name)
  character(len=*), intent(in) :: solver_name

  integer :: ios
  character(len=1000) :: msg

   call linpb%factor('analysis')
   call linpb%factor('factorization')
   call linpb%solve(x)
   call linpb%clean()

   write(msg,'(a,e23.15)') &
     "Done "//solver_name//"; residual ", norm2( rhs - matmul(x%x,mmmt) )
   call info('linsolver_test','',msg)

   open(io_unit,file='linsolver_test_data_'//solver_name//'.results', &
     status='replace',action='write',form='formatted',iostat=ios)
   if(ios.ne.0) then
     call warning('linsolver_test','', &
       'Unable to write linsolver_test_data_'//solver_name//'.results')
     return
   endif
   ! mmmt holds M^T: transpose back so the file contains M itself
   call write_octave(transpose(mmmt),'mmm',io_unit)
   call write_octave(rhs,'c','rhs',io_unit)
   call write_octave(x%x,'c', 'x' ,io_unit)
   close(io_unit,iostat=ios)
 end subroutine solve_and_report

 !> Euclidean norm of a vector.
 pure function norm2(x) result(n)
  real(wp), intent(in) :: x(:)
  real(wp) :: n
   n = sqrt(sum(x**2))
 end function norm2

end program linsolver_test

